{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# REINFORCE in TensorFlow (3 pts)\n",
    "\n",
    "This notebook implements a basic REINFORCE algorithm (a.k.a. policy gradient) for the CartPole environment.\n",
    "\n",
    "It has been deliberately written to be as simple and human-readable as possible.\n",
    "\n",
    "Authors: [Practical_RL](https://github.com/yandexdataschool/Practical_RL) course team"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# start a virtual display when DISPLAY is unset or empty (e.g. on a headless server)\n",
    "if not os.environ.get(\"DISPLAY\"):\n",
    "    !bash ../xvfb start\n",
    "    os.environ['DISPLAY'] = ':1'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The notebook assumes that you have [OpenAI Gym](https://github.com/openai/gym) installed.\n",
    "\n",
    "If you're running on a server, [use xvfb](https://github.com/openai/gym#rendering-on-a-server)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "env = gym.make(\"CartPole-v0\")\n",
    "\n",
    "# gym compatibility: unwrap TimeLimit\n",
    "if hasattr(env, '_max_episode_steps'):\n",
    "    env = env.env\n",
    "\n",
    "env.reset()\n",
    "n_actions = env.action_space.n\n",
    "state_dim = env.observation_space.shape\n",
    "\n",
    "plt.imshow(env.render(\"rgb_array\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Building the network for REINFORCE"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For the REINFORCE algorithm, we need a model that predicts action probabilities given states.\n",
    "\n",
    "For numerical stability, please __do not include the softmax layer in your network architecture__.\n",
    "The network should output raw logits; we'll apply softmax or log-softmax to them where appropriate."
   ]
  },
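  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why the softmax is kept out of the network, compare a naive `log(softmax(x))` with a shifted log-softmax on large logits. This is a minimal NumPy sketch, not the notebook's TensorFlow code; the helper names `naive_log_softmax` and `stable_log_softmax` are hypothetical and only illustrate the kind of trick a fused log-softmax can use:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def naive_log_softmax(logits):\n",
    "    # exponentiate first, then take the log: overflows for large logits\n",
    "    exp = np.exp(logits)\n",
    "    return np.log(exp / exp.sum(axis=-1, keepdims=True))\n",
    "\n",
    "def stable_log_softmax(logits):\n",
    "    # subtract the row maximum before exponentiating (log-sum-exp trick)\n",
    "    shifted = logits - logits.max(axis=-1, keepdims=True)\n",
    "    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))\n",
    "\n",
    "logits = np.array([[1000.0, 0.0]])  # made-up, deliberately extreme logits\n",
    "with np.errstate(over='ignore', invalid='ignore', divide='ignore'):\n",
    "    naive = naive_log_softmax(logits)   # contains nan / -inf\n",
    "stable = stable_log_softmax(logits)     # stays finite\n",
    "```\n",
    "\n",
    "With these inputs the naive version produces non-finite values, while the shifted version returns finite log-probabilities."
   ]
  },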
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "# create input variables. We only need <s,a,R> for REINFORCE\n",
    "states = tf.placeholder('float32', (None,)+state_dim, name=\"states\")\n",
    "actions = tf.placeholder('int32', name=\"action_ids\")\n",
    "cumulative_rewards = tf.placeholder('float32', name=\"cumulative_returns\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "<define network graph using raw tf or any deep learning library >\n",
    "\n",
    "\n",
    "logits = <linear outputs(symbolic) of your network >\n",
    "\n",
    "policy = tf.nn.softmax(logits)\n",
    "log_policy = tf.nn.log_softmax(logits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# utility function to get action probabilities pi(a|s) for a single state\n",
    "# (requires a default session, e.g. tf.InteractiveSession)\n",
    "def get_action_proba(s):\n",
    "    return policy.eval({states: [s]})[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Loss function and updates\n",
    "\n",
    "We now need to define the objective and the update rule for the policy gradient.\n",
    "\n",
    "Our objective function is the expected return, estimated from $N$ state-action pairs sampled with the current policy:\n",
    "\n",
    "$$ J \\approx  { 1 \\over N } \\sum  _{s_i,a_i} G(s_i,a_i) $$\n",
    "\n",
    "\n",
    "Following the REINFORCE algorithm, we let automatic differentiation do the work by defining a surrogate objective:\n",
    "\n",
    "$$ \\hat J \\approx { 1 \\over N } \\sum  _{s_i,a_i} \\log \\pi_\\theta (a_i | s_i) \\cdot G(s_i,a_i) $$\n",
    "\n",
    "When you compute the gradient of $\\hat J$ with respect to the network weights $ \\theta $, treating $G(s_i,a_i)$ as constants, you get exactly the REINFORCE estimate of the policy gradient.\n"
   ]
  },
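  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The step from the expected return to the log-probability objective rests on the log-derivative identity $\\nabla_\\theta \\pi_\\theta = \\pi_\\theta \\nabla_\\theta \\log \\pi_\\theta$. A sketch for a single fixed state $s$, assuming $G(s,a)$ does not depend on $\\theta$ (the full multi-step statement is the policy gradient theorem):\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "\\nabla_\\theta \\, \\mathbb{E}_{a \\sim \\pi_\\theta(\\cdot | s)} \\left[ G(s,a) \\right]\n",
    "&= \\sum_a \\nabla_\\theta \\pi_\\theta(a | s) \\cdot G(s,a) \\\\\n",
    "&= \\sum_a \\pi_\\theta(a | s) \\, \\nabla_\\theta \\log \\pi_\\theta(a | s) \\cdot G(s,a) \\\\\n",
    "&= \\mathbb{E}_{a \\sim \\pi_\\theta(\\cdot | s)} \\left[ \\nabla_\\theta \\log \\pi_\\theta(a | s) \\cdot G(s,a) \\right]\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "Replacing the last expectation with an average over the sampled pairs $(s_i, a_i)$ gives the gradient of the log-probability objective $\\hat J$."
   ]
  },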
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# select log-probabilities of the actions that were actually taken:\n",
    "# indices[i] = [i, actions[i]], so gather_nd picks log_policy[i, actions[i]]\n",
    "indices = tf.stack([tf.range(tf.shape(log_policy)[0]), actions], axis=-1)\n",
    "log_policy_for_actions = tf.gather_nd(log_policy, indices)"
   ]
  },
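  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `tf.stack` + `tf.gather_nd` pair above picks, for each row of `log_policy`, the entry in the column given by `actions`. A minimal NumPy sketch of the same selection follows; the array values and the `_np` names are made up for illustration and are not part of the notebook:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# log-probabilities for a batch of 3 states and 2 actions (illustrative values)\n",
    "log_policy_np = np.log(np.array([[0.9, 0.1],\n",
    "                                 [0.4, 0.6],\n",
    "                                 [0.5, 0.5]]))\n",
    "actions_np = np.array([0, 1, 1])  # action taken in each state\n",
    "\n",
    "# pair each row index with its action id, like tf.stack([...], axis=-1)\n",
    "indices_np = np.stack([np.arange(len(actions_np)), actions_np], axis=-1)\n",
    "\n",
    "# take log_policy_np[row, action] for every pair, like tf.gather_nd\n",
    "selected = log_policy_np[indices_np[:, 0], indices_np[:, 1]]\n",
    "```\n",
    "\n",
    "Here `selected` holds one log-probability per state, which is the shape the objective needs before it is multiplied by the cumulative rewards."
   ]
  },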
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# policy objective as in the last formula. please use mean, not sum.\n",
    "# note: you need to use log_policy_for_actions to get log probabilities for actions taken.\n",
    "\n",
    "J = <YOUR CODE>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# regularize with entropy\n",
    "entropy = <compute entropy. Don't forget the sign!>"
   ]
  },
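  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the entropy of a discrete policy in state $s$ is $H = -\\sum_a \\pi(a|s) \\log \\pi(a|s)$: it is largest for a uniform policy and zero for a deterministic one. Below is a small NumPy sketch of that formula, not the TensorFlow expression the cell above asks for; the function name `policy_entropy` is hypothetical and the `eps` guard is an assumption added for numerical safety:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def policy_entropy(probs, eps=1e-12):\n",
    "    # H = -sum_a p(a) * log p(a), computed per row; eps guards against log(0)\n",
    "    probs = np.asarray(probs, dtype=np.float64)\n",
    "    return -np.sum(probs * np.log(probs + eps), axis=-1)\n",
    "\n",
    "uniform = policy_entropy([0.5, 0.5])        # maximal for two actions: log(2)\n",
    "deterministic = policy_entropy([1.0, 0.0])  # no randomness: approximately 0\n",
    "```\n",
    "\n",
    "Because higher entropy means more exploration, the entropy enters the loss with a minus sign when the loss is minimized."
   ]
  },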
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# all network weights\n",
    "all_weights = <a list of all trainable weights in your network >\n",
    "\n",
    "# weight updates. maximizing J is same as minimizing -J. Adding negative entropy.\n",
    "loss = -J - 0.1*entropy\n",
    "\n",
    "update = tf.train.AdamOptimizer().minimize(loss, var_list=all_weights)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Computing cumulative rewards"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_cumulative_rewards(rewards,  # rewards at each step\n",
    "                           gamma=0.99  # discount for reward\n",
    "                           ):\n",
    "    \"\"\"\n",
    "    Take a list of immediate rewards r(s,a) for the whole session and\n",
    "    compute cumulative rewards R(s,a) (a.k.a. G(s,a) in Sutton '16):\n",
    "    R_t = r_t + gamma*r_{t+1} + gamma^2*r_{t+2} + ...\n",
    "\n",
    "    The simple way to compute cumulative rewards is to iterate from the last to the first time tick\n",
    "    and compute R_t = r_t + gamma*R_{t+1} recursively.\n",
    "\n",
    "    You must return an array/list of cumulative rewards with as many elements as in the initial rewards.\n",
    "    \"\"\"\n",
    "\n",
    "    <your code here >\n",
    "\n",
    "    return < array of cumulative rewards >"
   ]
  },
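  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a hand check of the recurrence $R_t = r_t + \\gamma R_{t+1}$ from the docstring above, take the rewards $[0, 0, 1, 0, 0, 1, 0]$ with $\\gamma = 0.9$ and work backwards from the last step (nothing follows it, so $R_6 = r_6$):\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "R_6 &= 0 \\\\\n",
    "R_5 &= 1 + 0.9 \\cdot 0 = 1 \\\\\n",
    "R_4 &= 0 + 0.9 \\cdot 1 = 0.9 \\\\\n",
    "R_3 &= 0 + 0.9 \\cdot 0.9 = 0.81 \\\\\n",
    "R_2 &= 1 + 0.9 \\cdot 0.81 = 1.729 \\\\\n",
    "R_1 &= 0 + 0.9 \\cdot 1.729 = 1.5561 \\\\\n",
    "R_0 &= 0 + 0.9 \\cdot 1.5561 = 1.40049\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "Read from $R_0$ to $R_6$, these are the values expected by the first `np.allclose` check in the next cell."
   ]
  },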
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert len(get_cumulative_rewards(range(100))) == 100\n",
    "assert np.allclose(get_cumulative_rewards([0, 0, 1, 0, 0, 1, 0], gamma=0.9), [\n",
    "                   1.40049, 1.5561, 1.729, 0.81, 0.9, 1.0, 0.0])\n",
    "assert np.allclose(get_cumulative_rewards(\n",
    "    [0, 0, 1, -2, 3, -4, 0], gamma=0.5), [0.0625, 0.125, 0.25, -1.5, 1.0, -4.0, 0.0])\n",
    "assert np.allclose(get_cumulative_rewards(\n",
    "    [0, 0, 1, 2, 3, 4, 0], gamma=0), [0, 0, 1, 2, 3, 4, 0])\n",
    "print(\"looks good!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_step(_states, _actions, _rewards):\n",
    "    \"\"\"given full session, trains agent with policy gradient\"\"\"\n",
    "    _cumulative_rewards = get_cumulative_rewards(_rewards)\n",
    "    update.run({states: _states, actions: _actions,\n",
    "                cumulative_rewards: _cumulative_rewards})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Playing the game"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_session(env, t_max=1000):\n",
    "    \"\"\"play env with REINFORCE agent and train at the session end\"\"\"\n",
    "\n",
    "    # arrays to record session\n",
    "    states, actions, rewards = [], [], []\n",
    "\n",
    "    s = env.reset()\n",
    "\n",
    "    for t in range(t_max):\n",
    "\n",
    "        # action probabilities array aka pi(a|s)\n",
    "        action_probas = get_action_proba(s)\n",
    "\n",
    "        a = <pick random action using action_probas >\n",
    "\n",
    "        new_s, r, done, info = env.step(a)\n",
    "\n",
    "        # record session history to train later\n",
    "        states.append(s)\n",
    "        actions.append(a)\n",
    "        rewards.append(r)\n",
    "\n",
    "        s = new_s\n",
    "        if done:\n",
    "            break\n",
    "\n",
    "    train_step(states, actions, rewards)\n",
    "\n",
    "    return sum(rewards)"
   ]
  },
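  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Inside `generate_session`, the action is sampled from `action_probas` rather than chosen greedily, so the recorded session is a sample from the current policy. A minimal sketch of drawing an action index from a probability vector with NumPy is shown below; the helper `sample_action`, the seed, and the probability values are assumptions for illustration only:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def sample_action(action_probas, rng):\n",
    "    # draw one index i with probability action_probas[i]\n",
    "    return int(rng.choice(len(action_probas), p=action_probas))\n",
    "\n",
    "rng = np.random.default_rng(0)  # seeded only to make the sketch repeatable\n",
    "probs = np.array([0.2, 0.8])    # illustrative pi(a|s) for two actions\n",
    "samples = [sample_action(probs, rng) for _ in range(10000)]\n",
    "freq_of_action_1 = np.mean(samples)  # should be close to 0.8\n",
    "```\n",
    "\n",
    "Over many draws the empirical frequency of each action approaches its probability, which is what the policy gradient estimate relies on."
   ]
  },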
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "s = tf.InteractiveSession()\n",
    "s.run(tf.global_variables_initializer())\n",
    "\n",
    "for i in range(100):\n",
    "\n",
    "    rewards = [generate_session(env) for _ in range(100)]  # generate new sessions\n",
    "\n",
    "    print(\"mean reward:%.3f\" % (np.mean(rewards)))\n",
    "\n",
    "    if np.mean(rewards) > 300:\n",
    "        print(\"You Win!\")\n",
    "        break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Results & video"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# record sessions\n",
    "import gym.wrappers\n",
    "monitor_env = gym.wrappers.Monitor(gym.make(\"CartPole-v0\"), directory=\"videos\", force=True)\n",
    "sessions = [generate_session(monitor_env) for _ in range(100)]\n",
    "monitor_env.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# show video\n",
    "from IPython.display import HTML\n",
    "import os\n",
    "\n",
    "video_names = list(\n",
    "    filter(lambda s: s.endswith(\".mp4\"), os.listdir(\"./videos/\")))\n",
    "\n",
    "HTML(\"\"\"\n",
    "<video width=\"640\" height=\"480\" controls>\n",
    "  <source src=\"{}\" type=\"video/mp4\">\n",
    "</video>\n",
    "\"\"\".format(\"./videos/\"+video_names[-1]))  # this may or may not be _last_ video. Try other indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# That's all, thank you for your attention!"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
